/**
 * Copyright (c) 2021, Ouster, Inc.
 * All rights reserved.
 *
 * @file
 * @brief ouster_pyclient_pcap python module
 */
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include <chrono>
#include <cmath>
#include <sstream>
#include <string>

#include "common.h"
#include "ouster/impl/build.h"
#include "ouster/indexed_pcap_reader.h"
#include "ouster/os_pcap.h"
#include "ouster/pcap.h"

using namespace ouster::sensor_utils;
namespace py = pybind11;

// Short names for the map types that init_pcap binds as Python classes.
using stream_map = std::unordered_map<stream_key, stream_data>;
using count_map = std::map<uint64_t, uint64_t>;

// Mark these types opaque so pybind11 passes them as the classes registered
// in init_pcap instead of applying its built-in conversions (e.g. the STL
// casters pulled in by pybind11/stl.h). These declarations have to appear
// before any binding code that uses the types.
// NOTE(review): count_map is bound as "CountMap" in init_pcap but is not
// declared opaque here, so with pybind11/stl.h included it is presumably
// still converted to a plain dict — confirm this is intended.
PYBIND11_MAKE_OPAQUE(std::shared_ptr<playback_handle>);
PYBIND11_MAKE_OPAQUE(std::shared_ptr<record_handle>);
PYBIND11_MAKE_OPAQUE(std::vector<stream_key>);
PYBIND11_MAKE_OPAQUE(std::vector<guessed_ports>);
PYBIND11_MAKE_OPAQUE(stream_map);
void init_pcap(py::module& m, py::module& root_m) {
    m.doc() = R"(Pcap bindings generated by pybind11.

This module is generated from the C++ code and not meant to be used directly.
)";

    // turn off signatures in docstrings: mypy stubs provide better types
    py::options options;
    options.disable_function_signatures();
    py::bind_vector<std::vector<stream_key>>(root_m, "VectorStreamKey");
    py::bind_vector<std::vector<int>>(root_m, "VectorInt");
    py::bind_vector<std::vector<guessed_ports>>(root_m, "VectorGuessedPorts");
    py::bind_vector<std::vector<uint8_t>>(root_m, "VectorUint8");
    py::bind_map<std::unordered_map<stream_key, stream_data>>(root_m,
                                                              "MapUdpStreams");
    py::bind_map<std::map<uint64_t, uint64_t>>(root_m, "CountMap");
    py::class_<packet_info, std::shared_ptr<packet_info>>(m, "packet_info")
        .def(py::init<>())
        .def("__repr__",
             [](const packet_info& data) {
                 std::stringstream result;
                 result << data;
                 return result.str();
             })
        .def_readwrite("dst_ip", &packet_info::dst_ip)
        .def_readwrite("src_ip", &packet_info::src_ip)
        .def_readwrite("dst_port", &packet_info::dst_port)
        .def_readwrite("src_port", &packet_info::src_port)
        .def_readonly("payload_size", &packet_info::payload_size)
        .def_property(
            "timestamp",
            [](packet_info& packet_info) -> double {
                return packet_info.timestamp.count() / 1e6;
            },
            [](packet_info& packet_info, double set) {
                std::chrono::microseconds msec{(int)(set * 1e6)};
                packet_info.timestamp = msec;
            })
        .def_readonly("file_offset", &packet_info::file_offset)
        .def_readonly("fragments_in_packet", &packet_info::fragments_in_packet)
        .def_readonly("ip_version", &packet_info::ip_version)
        .def_readonly("encapsulation_protocol",
                      &packet_info::encapsulation_protocol)
        .def_readonly("network_protocol", &packet_info::network_protocol);

    py::class_<guessed_ports, std::shared_ptr<guessed_ports>>(m,
                                                              "guessed_ports")
        .def(py::init<>())
        .def_readonly("lidar", &guessed_ports::lidar)
        .def_readonly("imu", &guessed_ports::imu);

    py::class_<stream_key, std::shared_ptr<stream_key>>(m, "stream_key")
        .def(py::init<>())
        .def("__repr__",
             [](const stream_key& data) {
                 std::stringstream result;
                 result << data;
                 return result.str();
             })
        .def_readonly("dst_ip", &stream_key::dst_ip)
        .def_readonly("src_ip", &stream_key::src_ip)
        .def_readonly("dst_port", &stream_key::dst_port)
        .def_readonly("src_port", &stream_key::src_port);

    py::class_<stream_data, std::shared_ptr<stream_data>>(m, "stream_data")
        .def(py::init<>())
        .def("__repr__",
             [](const stream_data& data) {
                 std::stringstream result;
                 result << data;
                 return result.str();
             })
        .def_readonly("count", &stream_data::count)
        .def_readonly("payload_size_counts", &stream_data::payload_size_counts)
        .def_readonly("fragment_counts", &stream_data::fragment_counts)
        .def_readonly("ip_version_counts", &stream_data::ip_version_counts);

    py::class_<stream_info, std::shared_ptr<stream_info>>(m, "stream_info")
        .def(py::init<>())
        .def("__repr__",
             [](const stream_info& data) {
                 std::stringstream result;
                 result << data;
                 return result.str();
             })
        .def_readonly("total_packets", &stream_info::total_packets)
        .def_readonly("encapsulation_protocol",
                      &stream_info::encapsulation_protocol)
        .def_readonly("timestamp_max", &stream_info::timestamp_max)
        .def_readonly("timestamp_min", &stream_info::timestamp_min)
        .def_readonly("udp_streams", &stream_info::udp_streams)
        .def_property_readonly("timestamp_max",
                               [](stream_info& info) -> double {
                                   return info.timestamp_max.count() / 1e6;
                               })
        .def_property_readonly("timestamp_min",
                               [](stream_info& info) -> double {
                                   return info.timestamp_min.count() / 1e6;
                               });
    // pcap reading
    py::class_<std::shared_ptr<playback_handle>>(m, "playback_handle");

    m.def("replay_initialize", &replay_initialize);

    m.def("replay_uninitialize", [](std::shared_ptr<playback_handle>& handle) {
        replay_uninitialize(*handle);
    });

    m.def("next_packet_info",
          [](std::shared_ptr<playback_handle>& handle, packet_info& packet_info)
              -> bool { return next_packet_info(*handle, packet_info); });

    m.def("get_stream_info",
          [](const std::string& file, int packets_to_process = -1) {
              return get_stream_info(file, packets_to_process);
          });

    m.def(
        "get_stream_info",
        [](const std::string& file,
           std::function<void(uint64_t, uint64_t, uint64_t)> progress_callback,
           int packets_per_callback, int packets_to_process = -1) {
            return get_stream_info(file, progress_callback,
                                   packets_per_callback, packets_to_process);
        });
    m.def("guess_ports", &guess_ports);
    m.def(
        "read_packet",
        [](std::shared_ptr<playback_handle>& handle, py::buffer buf) -> size_t {
            auto info = buf.request();
            if (info.format != py::format_descriptor<uint8_t>::format()) {
                throw std::invalid_argument(
                    "Incompatible argument: expected a bytearray");
            }
            return read_packet(*handle, static_cast<uint8_t*>(info.ptr),
                               info.size);
        });

    m.def("replay_reset", [](std::shared_ptr<playback_handle>& handle) {
        replay_reset(*handle);
    });

    // pcap writing
    py::class_<std::shared_ptr<record_handle>>(m, "record_handle");

    m.def("record_initialize",
          py::overload_cast<const std::string&, int, bool>(&record_initialize),
          py::arg("file_name"), py::arg("frag_size"),
          py::arg("use_sll_encapsulation") = false,
          R"(
                ``def record_initialize(file_name: str, frag_size: int,
                      use_sll_encapsulation: bool = ...) -> record_handle:``
                  
                  Initialize record handle for single sensor pcap files

            )");

    m.def("record_uninitialize", [](std::shared_ptr<record_handle>& handle) {
        record_uninitialize(*handle);
    });

    m.def("record_packet",
          [](std::shared_ptr<record_handle>& handle, const std::string& src_ip,
             const std::string& dst_ip, int src_port, int dst_port,
             py::buffer buf, double timestamp) {
              auto info = buf.request();
              if (info.format != py::format_descriptor<uint8_t>::format()) {
                  throw std::invalid_argument(
                      "Incompatible argument: expected a bytearray");
              }
              record_packet(*handle, src_ip, dst_ip, src_port, dst_port,
                            static_cast<uint8_t*>(info.ptr), info.size,
                            llround(timestamp * 1e6));
          });

    m.attr("__version__") = ouster::SDK_VERSION;

    m.def("record_packet", [](std::shared_ptr<record_handle>& handle,
                              const packet_info& info, py::buffer buf) {
        auto buf_info = buf.request();
        if (buf_info.format != py::format_descriptor<uint8_t>::format()) {
            throw std::invalid_argument(
                "Incompatible argument: expected a bytearray");
        }
        record_packet(*handle, info, static_cast<uint8_t*>(buf_info.ptr),
                      buf_info.size);
    });

    py::class_<PcapReader>(m, "PcapReader");  // TODO add more complete bindings

    py::class_<PcapIndex>(m, "PcapIndex")
        .def(py::init<int>())
        .def("frame_count", &PcapIndex::frame_count)
        .def("seek_to_frame", &PcapIndex::seek_to_frame)
        .def_property_readonly(
            "frame_indices",
            [](PcapIndex& self) {
                std::vector<py::array> l;
                for (auto& i : self.frame_indices_) {
                    l.push_back(py::array(py::dtype::of<uint64_t>(), i.size(),
                                          i.data(), py::cast(self)));
                }
                return l;
            })
        .def_readonly("frame_timestamp_indices",
                      &PcapIndex::frame_timestamp_indices_)
        .def_readonly("frame_id_indices", &PcapIndex::frame_id_indices_);

    py::class_<IndexedPcapReader, PcapReader>(m, "IndexedPcapReader")
        .def(py::init<const std::string&, const std::vector<std::string>&>())
        .def(py::init<const std::string&,
                      const std::vector<ouster::sensor::sensor_info>&>())
        .def("current_info", &IndexedPcapReader::current_info)
        .def("next_packet", &IndexedPcapReader::next_packet)
        .def("update_index_for_current_packet",
             &IndexedPcapReader::update_index_for_current_packet)
        .def("current_frame_id",
             [](IndexedPcapReader& reader) -> py::object {
                 if (auto frame_id = reader.current_frame_id()) {
                     return py::int_(*frame_id);
                 }
                 return py::none();
             })
        .def("reset",
             &IndexedPcapReader::reset)  // TODO move to PcapReader binding?
        .def("seek",
             &IndexedPcapReader::seek)  // TODO move to PcapReader binding?
        .def("build_index", &IndexedPcapReader::build_index)
        .def("get_index", &IndexedPcapReader::get_index)
        .def("current_data", [](IndexedPcapReader& reader) -> py::array {
            uint8_t* data = const_cast<uint8_t*>(reader.current_data());
            size_t data_size = reader.current_length();
            return py::array(py::dtype::of<uint8_t>(), data_size, data,
                             py::cast(reader));
        });
}
